# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
# This source file is part of the Cangjie project, licensed under Apache-2.0
# with Runtime Library Exception.
# 
# See https://cangjie-lang.cn/pages/LICENSE for license information.

from os import path
import random
import itertools

# Builtin integer types: signed first, then unsigned, each in widths
# 8/16/32/64 followed by the Native variant.
integer_types = [sign + 'Int' + width for sign in ('', 'U') for width in ('8', '16', '32', '64', 'Native')]
# Builtin numeric types: the integer types followed by the float types.
number_types = integer_types + ['Float' + width for width in ('16', '32', '64')]
# Every builtin type the generated tests exercise.
integral_types = number_types + ['String', 'Rune', 'Bool', 'Unit', 'Object']
# Literal Cangjie expressions for types whose value default_value() does not
# derive by rule. The 'Nothing' entry is an immediately-invoked lambda that
# assigns a `break` expression to a Nothing-typed variable and returns it.
default_value_map = dict(
  Bool='true',
  Unit='()',
  Rune="'1'",
  Object='Object()',
  String='"1"',
  C='C()',
  Nothing='{ => var n : Nothing; do { n = break } while(false); return n }()',
)
# Binary operators whose compound-assignment forms (op=) are tested.
ops = '+ - * ** / % & | ^ << >>'.split()

def default_value(ty : str) -> str:
  """Return a Cangjie expression that evaluates to a value of type ``ty``.

  Rules, checked in order: ``*Native`` integers use a conversion call
  (``IntNative(1)``); other integers use a suffixed literal (``Int8`` ->
  ``1i8``, ``UInt64`` -> ``1u64``); floats use ``1.0f<width>``; classes and
  impl classes are constructed with ``()``; an interface ``X`` is obtained
  from the static function ``CX.x()``; anything else is looked up in
  ``default_value_map`` (``KeyError`` if absent).
  """
  if 'Native' in ty:
    return '%s(1)' % ty
  if ty in integer_types:
    # Suffix = lower-cased first letter + everything after the 't' of 'Int'.
    sign, width = ty[0].lower(), ty[ty.index('t') + 1:]
    return '1%s%s' % (sign, width)
  if 'Float' in ty:
    return '1.0f%s' % ty[5:]
  if ty in classes or ty in impls:
    return '%s()' % ty
  if ty in interfaces:
    return 'C{0}.{1}()'.format(ty, ty.lower())
  return default_value_map[ty]

# Directory holding this generator; its base name is embedded in the output
# file names. NOTE: `dir` shadows the builtin, and the next statement rebinds
# `path` (imported as os.path) to a plain str pattern that write_counted()
# fills via str.format.
dir = path.dirname(path.realpath(__file__))
path = '{0}/test_{1}_{{}}.cj'.format(dir, path.basename(dir))
# Skeleton of every generated .cj test file. Its %s placeholders are filled,
# in order, with: the @Description text, the @Mode value, the @Negative value,
# the @CompileWarning value, a suffix appended to the @Comment line (extra
# header lines such as @Issue), the type declarations, and the body of main().
template = '''
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

/*
  @Assertion:   4.23(27) 3. The return type of the overloaded operator needs to be the same as or a subtype of the left
                operand, i.e., as for a, b, op in the expression a op=b, they need to pass the type check of
                a = a op b.
  @Description: %s
  @Mode: %s
  @Negative: %s
  @Structure: single
  @CompileWarning: %s
  @Comment: Auto-generated by gen.py%s
*/
%s
main() {
    %s
    return 0
}
'''

# Fixed seed: the negative-test sample drawn below, and hence the generated
# file set, is identical on every run.
random.seed(123)
# Number of (t0, t1, op) negative cases sampled from builtin type pairs.
limit = 100
# Custom type hierarchy. Each value is the string of ALL supertypes of its key
# (single-character names, direct supertype first).
classes = {'A': 'JI', 'B': 'AJI', 'C': 'BAJI', 'D': 'AJI'}
interfaces = {'I': '', 'J': 'I', 'K': 'JI', 'L': 'I'}
# Impl classes: 'C' + X directly implements interface X.
impls = {'CI': 'I', 'CJ': 'JI', 'CK': 'KJI', 'CL': 'LI'}
custom_type_dict_list = [classes, interfaces, impls]
custom_type_dict = {key: supers for table in custom_type_dict_list for key, supers in table.items()}
custom_types = list(itertools.chain.from_iterable(custom_type_dict_list))
types = integral_types + custom_types

# Positive cases: every builtin type with itself, and Nothing with every type.
subtypes = list(zip(integral_types, integral_types)) + [('Nothing', t) for t in types]
# Negative cases: a reproducible random sample of distinct builtin type pairs
# combined with each operator.
non_subtypes = random.sample(
  [(t0, t1, op) for t0, t1, op in itertools.product(integral_types, integral_types, ops) if t0 != t1],
  limit)

def subtype(t0, t1):
  """Return whether ``t0`` is ``t1`` or a declared subtype of it.

  Tuples are compared element-wise (``t0[i]`` against ``t1[i]``). For single
  types, the supertype strings in ``classes``, ``interfaces`` and ``impls``
  are sequences of single-character type names, so membership is tested per
  character. A plain substring test would wrongly accept e.g.
  ``subtype('A', 'JI')`` or ``subtype('A', '')``.
  """
  if isinstance(t0, tuple):
    return all([subtype(t, t1[i]) for i, t in enumerate(t0)])
  return t0 == t1 or any(t0 in table and t1 in list(table[t0]) for table in (classes, interfaces, impls))

def product(s1, s2):
  """Return the Cartesian product of ``s1`` and ``s2`` as a list of pairs.

  Pairs are ordered with ``s1`` varying slowest. Both arguments are
  materialized first, so one-shot iterators work; iterating a dict yields
  its keys.
  """
  firsts, seconds = list(s1), list(s2)
  return [(a, b) for a in firsts for b in seconds]

def warning(ty : str):
  """Return the @CompileWarning header value for operator return type ``ty``.

  ``'ignore'`` for ``Unit`` and ``Nothing`` (NOTE(review): presumably because
  such return types make the compiler emit warnings — confirm), ``'no'``
  for everything else.
  """
  if ty in ('Unit', 'Nothing'):
    return 'ignore'
  return 'no'

def issue(ty : str):
  """Return extra header text appended to the @Comment line for return type ``ty``.

  For ``Nothing`` this is an ``@Issue: 6518`` line (presumably a tracker
  reference for Nothing-returning operators — confirm); otherwise empty.
  """
  if ty != 'Nothing':
    return ''
  return '\n  @Issue: 6518'

def write_counted(contents : str):
  """Write ``contents`` to the next numbered output file.

  The file name comes from formatting the module-level ``path`` pattern with
  ``counter`` zero-padded to three digits. ``counter`` is advanced once the
  write call has returned.
  """
  global counter
  target = path.format('%03d' % counter)
  with open(target, 'w') as out:
    out.write(contents)
    counter += 1

def keyword(ty : str):
  """Return the Cangjie keyword that introduces the declaration of ``ty``.

  Interfaces use ``interface``; other custom types (classes and impl classes)
  use ``open class``; any other type is extended with ``extend``.
  """
  if ty in interfaces:
    return 'interface'
  return 'open class' if ty in custom_types else 'extend'

def name(ty : str, c : list):
  """Return the header text naming ``ty`` in its declaration.

  Builtin types are returned unchanged. A custom type gets ``<: S`` appended,
  where ``S`` is the first character of its supertype string. This happens only
  when ``S`` is itself among the types being declared (``c``); otherwise the
  bare name is returned.
  """
  if ty in integral_types:
    return ty
  supers = custom_type_dict[ty]
  if not supers or supers[0] not in c:
    return ty
  return '%s <: %s' % (ty, supers[0])

# Operator definition emitted in class bodies and `extend` blocks; filled with
# (operator, body expression). No return type is written, so the compiler
# presumably infers it from the body expression.
op_class = '    public operator func %s(x : Unit) { %s }'
# Abstract operator declaration emitted in interface bodies; filled with
# (operator, return type).
op_interface = '    operator func %s(x : Unit) : %s'

def decl(t0 : str, t1 : str, op : str = ''):
  """Build the Cangjie type declarations for one generated test.

  The declared types are:
    - ``t1`` itself when it is a builtin type (emitted as an ``extend``);
    - all of ``classes`` when ``t0`` or ``t1`` is a class;
    - all of ``interfaces`` plus ``impls`` when ``t0`` or ``t1`` is an
      interface or an impl class.
  Each declared type becomes ``<keyword> <name> {<members>}``.

  Members are emitted for ``t1``. When ``t1`` is an interface, they are also
  emitted for every non-interface type whose supertypes are all interfaces.
  The interface ``t1`` declares the operator(s) with return type ``t0``; every
  other type with members defines them returning ``default_value(t0)``. Impl
  classes with members also get a static factory (``CX`` gets
  ``x() : X``), which ``default_value`` uses to build interface values.

  Args:
    t0: return type of the overloaded operator(s).
    t1: type of the left operand, i.e. the type owning the operator(s).
    op: the single operator to declare; empty means every entry of ``ops``.

  Returns:
    The concatenated declaration source.
  """
  t1_is_in = t1 in interfaces
  c = [t1] if t1 in integral_types else []
  if t0 in classes or t1 in classes: c += list(classes.keys())
  if t0 in impls or t1 in impls or t0 in interfaces or t1 in interfaces: c += list(interfaces.keys()) + list(impls.keys())
  operators = [op] if op else ops
  decl_template = '''
%s %s {%s}
'''

  def has_members(t):
    # When t1 is an interface, c holds only custom types, so every t reaching
    # custom_type_dict[t] has an entry there.
    return t == t1 or (t1_is_in and t not in interfaces and all(s in interfaces for s in custom_type_dict[t]))

  def members(t):
    if not has_members(t): return ''
    if t == t1 and t1_is_in:
      lines = [op_interface % (o, t0) for o in operators]
    else:
      lines = [op_class % (o, default_value(t0)) for o in operators]
    # Emit the factory for every impl class with members, not only subtypes
    # of t1. When t0 is an interface, default_value(t0) (e.g. 'CJ.j()') calls
    # this factory. Negative tests must be able to resolve it so that they
    # fail only on the return-type rule. CX <: X, so it is always well-typed.
    factory = ('    static func %s() : %s { %s() }\n' % (t[1].lower(), t[1], t)) if t in impls else ''
    return '\n%s\n' % '\n'.join(lines) + factory

  return ''.join([decl_template % (keyword(t), name(t, c), members(t)) for t in c])

# Classify every pair of custom types from these families: subtype pairs
# become positive tests, and each non-subtype pair becomes one negative test
# per operator.
for pairs in (product(interfaces, interfaces), product(classes, classes),
              product(impls, interfaces), product(classes, interfaces)):
  subtypes.extend(pair for pair in pairs if subtype(*pair))
  non_subtypes.extend((t0, t1, op) for t0, t1 in pairs if not subtype(t0, t1) for op in ops)
  
# File numbering starts at 2. NOTE(review): presumably indices 000-001 belong
# to hand-written tests in the same directory; confirm before changing.
counter = 2


def _positive_test(t0, t1):
  """Render a run-mode test in which t1's operators return t0, a subtype of t1."""
  description = ('Checks that if overloaded operator op for type {t1} with argument type T\n'
                 '                has return type {t2}, which is a subtype of {t1},\n'
                 '                then compound assignment operator op= can be used with types {t1} and T.'
                 ).format(t1=t1, t2=t0)
  assignments = '\n    '.join('a %s= ()' % o for o in ops)
  main_body = 'var a = %s\n    %s\n' % (default_value(t1), assignments)
  return template % (description, 'run', 'no', warning(t0), issue(t0), decl(t0, t1), main_body)


def _negative_test(t0, t1, op):
  """Render a compile-only negative test in which t1's `op` returns t0, a non-subtype of t1."""
  description = ('Checks that if overloaded operator {op} for type {t1} with argument type T has \n'
                 '                return type {t0}, then compound assignment operator {op}= cannot be\n'
                 '                used with types {t1} and T.'
                 ).format(t1=t1, t0=t0, op=op)
  main_body = 'var a = %s\n    a %s= ()\n' % (default_value(t1), op)
  return template % (description, 'compileonly', 'yes', warning(t0), issue(t0), decl(t0, t1, op), main_body)


# Render every test before writing any file, then write them in order.
positive_tests = [_positive_test(t0, t1) for t0, t1 in subtypes]
negative_tests = [_negative_test(t0, t1, op) for t0, t1, op in non_subtypes]
for test in positive_tests + negative_tests:
  write_counted(test)
